"""Preprocess the config file for the cosimulation."""

__copyright__ = """
Copyright (c) 2024 RapidStream Design Automation, Inc. and contributors.
All rights reserved. The contributor(s) of this file has/have agreed to the
RapidStream Contributor License Agreement.
"""

import glob
import json
import logging
import os
import shutil
import sys
import zipfile
from pathlib import Path
from xml.etree import ElementTree as ET

from tapa.cosim.common import Arg, Port

# Module-level logger, obtained as a child of the root logger and named after
# this module.
_logger = logging.getLogger().getChild(__name__)


def _update_relative_path(config: dict, config_path: str) -> None:
    """Resolve relative paths in the config against the config file's directory.

    The ``xo_path`` entry and every value of the ``axi_to_data_file`` mapping
    are rewritten in place. Paths starting with ``/`` or ``~`` are considered
    already anchored and are left untouched.

    Args:
        config: The loaded config dictionary; modified in place.
        config_path: Path of the config file the dictionary was loaded from.
    """
    # dirname() is empty for a bare file name such as "config.json". Joining
    # with an empty directory keeps the path relative to the working directory
    # instead of accidentally anchoring it at the filesystem root ("/a.xo").
    config_dir = os.path.dirname(config_path)

    def resolve(path: str) -> str:
        if path.startswith(("/", "~")):
            return path
        return os.path.join(config_dir, path)

    config["xo_path"] = resolve(config["xo_path"])

    # Tolerate a null mapping the same way an empty one is handled.
    axi_to_data_file = config["axi_to_data_file"] or {}
    for axi_name, curr_path in axi_to_data_file.items():
        # Only existing keys are reassigned, so iterating meanwhile is safe.
        axi_to_data_file[axi_name] = resolve(curr_path)


def extract_part_from_xml_file(file_path: str) -> str:
    """
    Extracts the text of the first <Part> element from the given XML file.

    Args:
        file_path (str): The path to the XML file.

    Returns:
        The text of the <Part> element with surrounding whitespace removed.

    Raises:
        ValueError: If the file does not exist or is not well-formed XML, if
            it contains no <Part> element, or if that element has no text.
    """
    try:
        tree = ET.parse(file_path)
    except (ET.ParseError, FileNotFoundError) as e:
        # Raise instead of returning the message: a returned string would be
        # indistinguishable from a real part number for the caller.
        msg = f"Error reading or parsing the XML file: {e}"
        raise ValueError(msg) from e

    part_element = tree.getroot().find(".//Part")
    if part_element is None:
        msg = "The XML file does not contain a <Part> element."
        raise ValueError(msg)

    # ``text`` is None for an empty element such as <Part/>.
    part = (part_element.text or "").strip()
    if not part:
        msg = "The <Part> element in the XML file is empty."
        raise ValueError(msg)

    return part


def parse_part_num(xo_dir: str) -> str | None:
    """Look up the part number recorded in an extracted xo directory.

    The first ``*_csynth.xml`` file that glob reports under
    ``<xo_dir>/report`` is handed to ``extract_part_from_xml_file``.

    Args:
        xo_dir: Directory the xo file has been extracted into.

    Returns:
        Whatever ``extract_part_from_xml_file`` yields for that report, or
        None (after logging a warning) when the report directory or a csynth
        report is missing.
    """
    report_dir = f"{xo_dir}/report"

    if os.path.exists(report_dir):
        csynth_reports = glob.glob(f"{report_dir}/*_csynth.xml")
        if csynth_reports:
            return extract_part_from_xml_file(csynth_reports[0])
        warning = (
            "No csynth report found in %s, check if the xo is generated by "
            "the latest TAPA"
        )
    else:
        warning = (
            "The report directory %s does not exist inside xo, check if the "
            "xo is generated by the latest TAPA"
        )

    _logger.warning(warning, report_dir)
    return None


def _parse_xo_update_config(config: dict, tb_output_dir: str) -> None:
    """Unpack the xo file and fill the config with the kernel information.

    Only supports TAPA xo. Vitis XO has different hierarchy and RTL coding style

    The xo at ``config["xo_path"]`` is copied to and extracted in a per-user
    scratch directory under ``tb_output_dir`` (wiped first if it exists). The
    config is then updated in place:

    * ``verilog_path``: the single ``ip_repo/*/src`` directory of the xo.
    * ``top_name``: the ``name`` attribute of the kernel in ``kernel.xml``.
    * ``args``: the list of ``Arg`` objects built from ``kernel.xml``.
    * ``scalar_to_val``, ``axi_to_data_file``, ``axis_to_data_file`` and
      ``axi_to_c_array_size``: keys converted from argument ids to argument
      names; a missing or null mapping becomes an empty dict.
    * ``part_num``: the result of ``parse_part_num`` on the extracted xo.

    Args:
        config: The loaded config dictionary; modified in place.
        tb_output_dir: Directory under which the scratch directory is created.

    Raises:
        AssertionError: If the xo does not contain exactly one
            ``ip_repo/*/src`` directory.
        SystemExit: If ``kernel.xml`` or its ``kernel`` element is not found.
        KeyError: If an argument refers to an undeclared port, or the config
            refers to an argument id that ``kernel.xml`` does not declare.
    """
    xo_path = config["xo_path"]

    tmp_path = f"{tb_output_dir}/tapa_fast_cosim_{os.getuid()}/"
    shutil.rmtree(tmp_path, ignore_errors=True)
    Path(tmp_path).mkdir(parents=True, exist_ok=True)
    shutil.copy(xo_path, f"{tmp_path}/target.xo")
    # The context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(f"{tmp_path}/target.xo", "r") as zip_ref:
        zip_ref.extractall(tmp_path)

    # only supports tapa xo
    src_dirs = glob.glob(f"{tmp_path}/ip_repo/*/src")
    assert len(src_dirs) == 1, "Only supports TAPA XO. Vitis XO is not supported"
    config["verilog_path"] = src_dirs[0]

    # extract other kernel information
    kernel_files = glob.glob(f"{tmp_path}/*/kernel.xml")
    if not kernel_files:
        # Report a missing kernel.xml the same way as a missing kernel element
        # instead of failing with a bare IndexError.
        _logger.error("Fail to find kernel.xml in %s", xo_path)
        sys.exit(1)
    kernel_xml = ET.parse(kernel_files[0]).getroot().find("./kernel")
    if kernel_xml is None:
        _logger.error("Fail to extract kernel name")
        sys.exit(1)
    config["top_name"] = kernel_xml.attrib["name"]

    # parse kernel ports and args
    ports = {}
    for port_xml in kernel_xml.findall("./ports/port"):
        port = Port(
            name=port_xml.attrib["name"],
            mode=port_xml.attrib["mode"],
            data_width=int(port_xml.attrib["dataWidth"]),
        )
        ports[port.name] = port
        _logger.debug("port: %s", port)
    args = []
    for arg_xml in kernel_xml.findall("./args/arg"):
        arg = Arg(
            name=arg_xml.attrib["name"],
            address_qualifier=int(arg_xml.attrib["addressQualifier"]),
            id=int(arg_xml.attrib["id"]),
            port=ports[arg_xml.attrib["port"]],
        )
        args.append(arg)
        _logger.debug("arg: %s", arg)
    config["args"] = args

    # convert argument index in the config file to actual names
    id_to_name = {arg.id: arg.name for arg in args}

    def change_id_to_name(id_to_val: dict[str, str]) -> dict[str, str]:
        # Keys are argument ids as written in the config file; int() maps them
        # onto the integer ids read from kernel.xml.
        return {
            id_to_name[int(scalar_arg_id)]: val
            for scalar_arg_id, val in id_to_val.items()
        }

    for entry in (
        "scalar_to_val",
        "axi_to_data_file",
        "axis_to_data_file",
        "axi_to_c_array_size",
    ):
        # A mapping that is absent or null is treated as empty.
        config[entry] = change_id_to_name(config.get(entry) or {})

    config["part_num"] = parse_part_num(tmp_path)


def _check_scalar_val_format(config: dict) -> None:
    """Validate the textual format of every scalar argument value.

    Each value of ``config["scalar_to_val"]`` must start with ``'h`` and be at
    most 18 characters long, i.e. the ``'h`` prefix plus 16 hex digits.

    Args:
        config: The config dictionary holding the ``scalar_to_val`` mapping.

    Raises:
        AssertionError: If a value violates the format. The checks are
            explicit raises rather than ``assert`` statements so that they
            still run under ``python -O``; the exception type is unchanged.
    """
    prefix = "'h"
    max_hex_digits = 16  # 64 bits at 4 bits per hex digit

    for scalar, val in config["scalar_to_val"].items():
        if not val.startswith(prefix):
            msg = (
                "scalar value should be written in hex format, lsb on the right, "
                "with the prefix 'h according to Verilog syntax. "
                f"Violation: {scalar}: {val}"
            )
            raise AssertionError(msg)
        if len(val) > len(prefix) + max_hex_digits:
            msg = f"scalar value should be at most 64 bit. Violation: {scalar}: {val}"
            raise AssertionError(msg)


def preprocess_config(
    config_path: str, tb_output_dir: str, part_num: str | None
) -> dict:
    """Load the cosimulation config file and bring it into its final form.

    Args:
        config_path: Path of the JSON config file.
        tb_output_dir: Output directory passed on to the xo extraction step.
        part_num: Part number that, when non-empty, replaces the one derived
            from the xo.

    Returns:
        The config dictionary after path resolution, xo parsing and scalar
        value validation.
    """
    config = json.loads(Path(config_path).read_text(encoding="utf-8"))

    # a design without scalar arguments may omit the entry altogether
    if "scalar_to_val" not in config:
        config["scalar_to_val"] = {}

    _update_relative_path(config, config_path)
    _parse_xo_update_config(config, tb_output_dir)
    _check_scalar_val_format(config)

    # an explicitly given part number wins over the one found in the xo
    if part_num:
        config["part_num"] = part_num

    return config
